#ifndef MAPREDUCE_REDUCE_
#define MAPREDUCE_REDUCE_

#include <functional>
#include <omp.h>

#include "common"
#include "range"

namespace map_reduce {

/*
 * Compile-time scheduling policies for reduce().
 *
 * Every policy type exposes two static constants:
 *   - level: the dimension index the policy refers to;
 *   - when:  the base_policy that decides whether that dimension is reduced
 *            in parallel.
 */
namespace reduce_sched {
    // How a policy asks for parallel execution. wrapper_reduce takes its
    // parallel branch for `always` whenever the dimension matches the policy
    // level, for `root` only if additionally `in_parallel` is false, and
    // never for `never`.
    enum base_policy {
        always = 5,
        root   = 6,
        never  = 7
    };

    // Parallel at dimension `Level`, subject to the `root` rule.
    template <unsigned Level = 0>
    struct parallel {
        static constexpr unsigned level = Level;
        static constexpr base_policy when = root;
    };

    // Parallel at dimension `Level`, subject to the `always` rule.
    template <unsigned Level = 0>
    struct parallel_always {
        static constexpr unsigned level = Level;
        static constexpr base_policy when = always;
    };

    // Purely sequential reduction.
    struct serial {
        static constexpr unsigned level = 0;
        static constexpr base_policy when = never;
    };

    // Policy picked when the caller does not name one.
    // Alternative kept for reference: using automatic = parallel<0>;
    using automatic = serial;
}


/*
 * Ready-made binary combiners for reductions over values of type T.
 * Every function takes both operands by value and returns a T.
 */
template <typename T>
struct reduce_ops {
    // a + b
    static inline T add(T a, T b)
    {
        return a + b;
    }

    // a * b
    static inline T mul(T a, T b)
    {
        return a * b;
    }

    // Larger operand according to operator>; yields b when neither is greater.
    static inline T greater_than(T a, T b)
    {
        if (a > b) {
            return a;
        }
        return b;
    }

    // Smaller operand according to operator<; yields b when neither is smaller.
    static inline T less_than(T a, T b)
    {
        if (a < b) {
            return a;
        }
        return b;
    }
};

/*
 * Calls the accessor after padding the index list.
 *
 * Level is the number of indices still missing. Each recursion step appends
 * the `begin` value of the next dimension of `r` (the dimension whose number
 * equals the count of indices supplied so far) and recurses with Level - 1.
 * The Level == 0 specialisation performs the actual call.
 */
template <unsigned Level, typename Access, typename Range>
struct invoker {
    template <typename... Indices>
    static inline typename reduce_traits<Access>::return_type
    invoke(Access a, Range r, Indices... indices)
    {
        using remaining = invoker<Level - 1, Access, Range>;

        // First position of the next dimension that has no index yet.
        const auto next_begin = r.get_dim(sizeof...(Indices)).begin;

        return remaining::invoke(a, r, indices..., next_begin);
    }
};

/*
 * End of the invoker recursion: no index is missing any more, so the accessor
 * is called with exactly the indices collected so far. The range argument is
 * accepted for signature compatibility and is not consulted.
 */
template <typename Access, typename Range>
struct invoker<0, Access, Range> {
    template <typename... Indices>
    static inline typename reduce_traits<Access>::return_type
    invoke(Access a, Range r, Indices... indices)
    {
        return a(indices...);
    }
};

/*
 * Recursive reduction driver over a multi-dimensional range.
 *
 * One instantiation handles one dimension of `Range`. The entry point is
 * instantiated with Level == Range::NDims and every recursion step fixes one
 * more index and decrements Level, so the dimension being iterated is always
 * Range::NDims - Level. The Level == 0 specialisation combines a single
 * element with the running value.
 *
 * Template parameters:
 *   Level   dimensions that still have to be iterated (including this one).
 *   Ret     type of the reduction result; must be default-constructible.
 *   Access  callable returning the element at a full set of indices.
 *   Func    binary combiner, called as f(x, y).
 *   Range   provides NDims, type, dim_type (with begin/end/step) and get_dim().
 *   Policy  a reduce_sched policy (static members `when` and `level`).
 *
 * The innermost dimension seeds its running value with the element at the
 * start of the (chunk of the) range, so that range is assumed to be
 * non-empty. In the parallel branches the per-chunk results are merged in
 * whatever order the threads finish; the combiner therefore has to be
 * associative and commutative for the result to be deterministic.
 */
template <unsigned Level, typename Ret, typename Access, typename Func, typename Range, typename Policy>
struct wrapper_reduce {
    /*
     * Reduces the dimension selected by the number of indices already fixed.
     *
     *   a, f, r  accessor, combiner and range (see the struct description).
     *   tmp      running value handed down by the caller; only the Level == 0
     *            specialisation consumes it, it is ignored here.
     *   args     indices already fixed by the enclosing dimensions.
     *
     * Returns the reduction of every element reachable with the fixed indices.
     */
    template<typename... Args>
    inline
    static Ret
    reduce(Access a, Func f, Range r, Ret tmp, Args... args)
    {
        using mydim      = typename Range::dim_type;
        using index_type = typename Range::type;
        using next_level = wrapper_reduce<Level - 1, Ret, Access, Func, Range, Policy>;

        (void) tmp;

        // One dimension per index that has already been fixed.
        mydim d = r.get_dim(sizeof...(Args));

        // With an integral index type this rounds down; the last chunk always
        // runs up to d.end, so no element is lost because of the rounding.
        size_t steps = (d.end - d.begin) / d.step;

        // True when fixing one more index yields a complete element address.
        const bool innermost = (Range::NDims == (sizeof...(Args) + 1));

        const bool at_policy_level = (Policy::level == Range::NDims - Level);

        // Value-initialised so that no indeterminate value is ever copied.
        Ret ret = Ret();

        if ((Policy::when == reduce_sched::base_policy::always && at_policy_level) ||
            (Policy::when == reduce_sched::base_policy::root && at_policy_level && !in_parallel)) {
            // NOTE(review): parallel execution of a nested dimension aborts
            // unconditionally, which makes the rest of this branch
            // unreachable. The guard is kept as found; confirm whether it is
            // still required before relying on the code below.
            abort();

            const size_t procs       = static_cast<size_t>(omp_get_num_procs());
            const size_t chunks      = (steps >= procs) ? procs : 1;
            const size_t local_steps = steps / chunks;

            bool first = true;

            #pragma omp parallel for shared(ret, first)
            for (size_t i = 0; i < chunks; ++i) {
                //in_parallel = true;

                // Chunk bounds are expressed in index units, hence the
                // multiplication by the stride.
                const index_type chunk_begin = d.begin + i * local_steps * d.step;
                index_type chunk_end = d.end;
                if (i != chunks - 1) {
                    chunk_end = d.begin + (i + 1) * local_steps * d.step;
                }

                Ret partial = Ret();
                if (innermost) {
                    // Seed with the first element of the chunk ...
                    partial = invoker<Level - 1, Access, Range>::invoke(a, r, args..., chunk_begin);
                }

                // ... and skip it in the loop below.
                index_type local_begin = chunk_begin;
                if (innermost) {
                    local_begin += d.step;
                }

                for (index_type t = local_begin; t < chunk_end; t += d.step) {
                    if (innermost || t == local_begin) {
                        partial = next_level::reduce(a, f, r, partial, args..., t);
                    } else {
                        partial = f(partial, next_level::reduce(a, f, r, partial, args..., t));
                    }
                }

                #pragma omp critical
                {
                    if (first) {
                        ret = partial;
                        first = false;
                    } else {
                        ret = f(partial, ret);
                    }
                }
            }
        } else {
            Ret partial = Ret();
            if (innermost) {
                // Take the first element as the first partial value
                partial = invoker<Level - 1, Access, Range>::invoke(a, r, args..., d.begin);
            }

            // skip the first element if necessary
            index_type begin = d.begin;
            if (innermost) {
                begin += d.step;
            }

            for (index_type t = begin; t < d.end; t += d.step) {
                if (innermost || t == begin) {
                    // Innermost: fold one element into the running value.
                    // Otherwise: the first slice provides the initial value.
                    partial = next_level::reduce(a, f, r, partial, args..., t);
                } else {
                    partial = f(partial, next_level::reduce(a, f, r, partial, args..., t));
                }
            }
            ret = partial;
        }

        return ret;
    }

    /*
     * Entry point: reduces dimension 0 with no index fixed yet.
     *
     *   a, f, r  accessor, combiner and range (see the struct description).
     *
     * Returns the reduction of every element of the range.
     */
    inline
    static Ret
    reduce(Access a, Func f, Range r)
    {
        using mydim      = typename Range::dim_type;
        using index_type = typename Range::type;
        using next_level = wrapper_reduce<Level - 1, Ret, Access, Func, Range, Policy>;

        mydim d = r.get_dim(0);

        // With an integral index type this rounds down; the last chunk always
        // runs up to d.end, so no element is lost because of the rounding.
        size_t steps = (d.end - d.begin) / d.step;

        // True when a single index already addresses an element.
        const bool innermost = (Range::NDims == 1);

        const bool at_policy_level = (Policy::level == Range::NDims - Level);

        // Value-initialised so that no indeterminate value is ever copied.
        Ret ret = Ret();

        if ((Policy::when == reduce_sched::base_policy::always && at_policy_level) ||
            (Policy::when == reduce_sched::base_policy::root && at_policy_level && !in_parallel)) {
            //printf("<R> Parallel: %u\n", NDims - Level);
            const size_t procs = static_cast<size_t>(omp_get_num_procs());

            // At most one chunk per processor and never more chunks than
            // steps. `steps` is 0 when the span is shorter than one stride;
            // a single chunk then covers the whole range instead of dividing
            // by zero.
            size_t chunks = (steps >= procs) ? procs : steps;
            if (chunks == 0) {
                chunks = 1;
            }
            const size_t local_steps = steps / chunks;

            bool first = true;

            #pragma omp parallel for shared(ret, first)
            for (size_t i = 0; i < chunks; ++i) {
                // Chunk bounds are expressed in index units, hence the
                // multiplication by the stride.
                const index_type chunk_begin = d.begin + i * local_steps * d.step;
                index_type chunk_end = d.end;
                if (i != chunks - 1) {
                    chunk_end = d.begin + (i + 1) * local_steps * d.step;
                }

                Ret partial = Ret();
                if (innermost) {
                    // Seed with the first element of the chunk ...
                    partial = invoker<Level - 1, Access, Range>::invoke(a, r, chunk_begin);
                }

                // ... and skip it in the loop below.
                index_type local_begin = chunk_begin;
                if (innermost) {
                    local_begin += d.step;
                }

                for (index_type t = local_begin; t < chunk_end; t += d.step) {
                    if (innermost || t == local_begin) {
                        partial = next_level::reduce(a, f, r, partial, t);
                    } else {
                        partial = f(partial, next_level::reduce(a, f, r, partial, t));
                    }
                }

                // Merge this chunk into the shared result, one thread at a time.
                #pragma omp critical
                {
                    if (first) {
                        ret = partial;
                        first = false;
                    } else {
                        ret = f(partial, ret);
                    }
                }
            }
        } else {
            //printf("<R> NOT Parallel: %u\n", NDims - Level);
            Ret partial = Ret();
            if (innermost) {
                // Take the first element as the first partial value
                partial = invoker<Level - 1, Access, Range>::invoke(a, r, d.begin);
            }

            // skip the first element if necessary
            index_type begin = d.begin;
            if (innermost) {
                begin += d.step;
            }

            for (index_type t = begin; t < d.end; t += d.step) {
                if (innermost || t == begin) {
                    // Innermost: fold one element into the running value.
                    // Otherwise: the first slice provides the initial value.
                    partial = next_level::reduce(a, f, r, partial, t);
                } else {
                    partial = f(partial, next_level::reduce(a, f, r, partial, t));
                }
            }

            ret = partial;
        }

        return ret;
    }
};

/*
 * Bottom of the wrapper_reduce recursion: every index is fixed, so the
 * addressed element is fetched through the accessor and combined with the
 * running value as f(element, running). The range argument is not consulted.
 */
template <typename Ret, typename Access, typename Func, typename Range, typename Policy>
struct wrapper_reduce<0, Ret, Access, Func, Range, Policy> {
    template <typename... Indices>
    static inline Ret
    reduce(Access a, Func f, const Range & /*r*/, Ret tmp, Indices... indices)
    {
        return f(a(indices...), tmp);
    }
};

/*
 * Reduces every element reachable through `a` over range `r` with the binary
 * combiner `f`, scheduled according to `Policy`.
 *
 * The policy object only selects the policy type; its value is never read.
 * The result type is the one reduce_traits reports for `Func`.
 */
template <typename Access, typename Func, typename Range, typename Policy>
inline
static typename reduce_traits<Func>::return_type
reduce(Access a, Func f, Range r, const Policy & /*p*/)
{
    using result_type = typename reduce_traits<Func>::return_type;
    using driver      = wrapper_reduce<Range::NDims, result_type, Access, Func, Range, Policy>;

    // Start the recursion at the outermost dimension.
    return driver::reduce(a, f, r);
}

/*
 * Convenience overload: same as the four-argument reduce() with the
 * reduce_sched::automatic policy.
 */
template <typename Access, typename Func, typename Range>
inline
static typename reduce_traits<Func>::return_type
reduce(Access a, Func f, Range r)
{
    const reduce_sched::automatic default_policy{};

    return reduce(a, f, r, default_policy);
}


/*
 * REDUCTION(name, op) defines reduce_<name>(a, r): a reduction of the
 * elements produced by accessor `a` over range `r`, combined with
 * reduce_ops<T>::op, where T is the accessor's return type as reported by
 * reduce_traits.
 *
 * The scheduling policy is the optional third template argument and defaults
 * to reduce_sched::automatic. It is forwarded to reduce(), so that
 *     reduce_sum<Access, Range, reduce_sched::parallel<0>>(a, r)
 * really uses the requested policy.
 */
#define REDUCTION(name,op)                                                            \
template <typename Access, typename Range, typename Policy = reduce_sched::automatic> \
typename reduce_traits<Access>::return_type                                           \
reduce_##name (Access a,                                                              \
               Range r)                                                               \
{                                                                                     \
    return reduce(a,                                                                  \
                  reduce_ops<typename reduce_traits<Access>::return_type>::op,        \
                  r,                                                                  \
                  Policy());                                                          \
}


REDUCTION(min,  less_than)
REDUCTION(max,  greater_than)
REDUCTION(sum,  add)
REDUCTION(prod, mul)

}

#endif

/* vim:set ft=cpp backspace=2 tabstop=4 shiftwidth=4 textwidth=120 foldmethod=marker expandtab: */
